import torch
from torch import nn
from torch.optim.sgd import SGD

from nonneg_sgd import *
from flow import *
from util import *
from evaluation import *

def fun_opt_lamb(lamb, args):
    '''
        Objective evaluated by the line search over lambda in the
        min-div algorithm.

        ``args`` is unpacked as
        (G, index, valid_flows, A, b, n_iter, lr, nonneg).
        A Tikhonov model is trained with a regularization vector whose
        entries all equal ``lamb`` and with the prior produced by
        ``initialize_flows(..., zeros=True)``; the rmse of the resulting
        flows against ``valid_flows`` is returned.
    '''
    G, index, valid_flows, A, b, n_iter, lr, nonneg = args
    n_unknowns = A.shape[1]

    # One regularization weight per unknown, all set to the candidate lambda.
    reg_vec = lamb * np.ones((n_unknowns, 1))
    x_start = initialize_flows(n_unknowns)
    x_prior = initialize_flows(n_unknowns, zeros=True)

    solver = Tikhonov(A, b, n_iter, lr, nonneg)
    solver.train(reg_vec, x_start, x_prior, verbose=False)

    predicted = get_dict_flows_from_tensor(index, solver.x, valid_flows)
    return rmse(G, valid_flows, predicted, {})

def get_prior(G, priors, train, index):
    '''
        Builds the column tensor of prior flows for the edges of G that
        are not in ``train``.

        The result has G.number_of_edges() - len(train) rows and one
        column; the row for an edge e is index[e] and its value is
        priors[e]. A float tensor is returned, placed on the GPU when
        CUDA is available.
    '''
    n_rows = G.number_of_edges() - len(train)
    column = np.zeros((n_rows, 1))

    # Only edges outside the training set receive a prior entry.
    held_out = [edge for edge in G.edges() if edge not in train]
    for edge in held_out:
        column[index[edge], 0] = priors[edge]

    if torch.cuda.is_available():
        tensor_type = torch.cuda.FloatTensor
    else:
        tensor_type = torch.FloatTensor

    return tensor_type(column)

class Tikhonov(nn.Module):
    '''
        Solves a Tikhonov-style regularization problem by iterative,
        gradient-based optimization:

            x* = min_x ||Ax-b||_2^2 + sum(reg_vec * (x-x_prior)^2)

        where the products in the second term are element-wise.

        NOTE: ``train`` below overrides ``nn.Module.train(mode)`` with a
        different signature, so the usual ``.eval()`` / ``.train(False)``
        mode switches must not be used on instances of this class.
    '''
    def __init__(self, A, b, n_iter, learning_rate, nonneg=False):
        '''
            A: matrix handed to sparse_tensor_from_coo_matrix
            b: right-hand side; stored as the transpose of a float
               tensor built from ``[b]``
            n_iter: number of optimization steps run by ``train``
            learning_rate: step size given to the optimizer
            nonneg: if True, ``train`` optimizes with NONNEGSGD instead
                    of Adam
        '''
        super(Tikhonov, self).__init__()

        self.use_cuda = torch.cuda.is_available()

        if self.use_cuda:
            # Called before any attribute other than use_cuda is set.
            self.cuda()

        self.A = sparse_tensor_from_coo_matrix(A)

        tensor_type = torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor
        self.b = tensor_type([b]).T

        self.n_iter = n_iter
        self.learning_rate = learning_rate
        self.nonneg = nonneg

    def forward(self, x, reg_vec, x_prior):
        '''
            returns loss for given x:
            squared 2-norm of the residual A x - b plus the penalty
            sum(reg_vec * (x - x_prior)^2).
        '''
        residual = torch.sparse.mm(self.A, x) - self.b
        deviation = x - x_prior

        data_term = torch.square(torch.norm(residual, 2))
        reg_term = torch.sum(torch.mul(deviation, torch.mul(reg_vec, deviation)))

        return data_term + reg_term

    def train(self, reg_vec, x_init, x_prior, verbose=False):
        '''
            Computes x* given the regularization weights (reg_vec),
            starting from x_init, and stores it in ``self.x``.

            reg_vec: array-like of weights, converted to a float tensor
            x_init: tensor optimized in place by the optimizer
            x_prior: tensor that x is pulled towards by the penalty
            verbose: if truthy, prints the loss every 1000 epochs

            Returns the list of loss values, one per epoch (each one
            evaluated before that epoch's optimizer step).
        '''
        self.x = x_init

        tensor_type = torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor
        t_reg_vec = tensor_type(reg_vec)

        if self.nonneg:
            self.optimizer = NONNEGSGD([self.x], lr=self.learning_rate)
        else:
            # Referenced through ``torch`` so the class does not depend on
            # an ``optim`` name leaking in from a star import.
            self.optimizer = torch.optim.Adam([self.x], lr=self.learning_rate)

        losses = []

        for epoch in range(self.n_iter):
            self.optimizer.zero_grad()
            loss = self.forward(self.x, t_reg_vec, x_prior)
            loss.backward()
            self.optimizer.step()

            loss_value = loss.item()

            if verbose and epoch % 1000 == 0:
                print("epoch: ", epoch, " loss = ", loss_value)

            losses.append(loss_value)

        return losses
